"""
This module provides result classes for multi-trajectory solvers.
Note that single trajectories are described by regular `Result` objects from the
`qutip.solver.result` module.
"""
# Required for Sphinx to follow autodoc_type_aliases
from __future__ import annotations

from typing import TypedDict
from ..core.numpy_backend import np

from copy import copy

from .result import _BaseResult
from ..core import qzero_like

__all__ = [
    "MultiTrajResult",
    "McResult",
    "NmmcResult",
]


class MultiTrajResultOptions(TypedDict):
    """
    Keys of the options dictionary read by :class:`MultiTrajResult`.
    """
    # When truthy (and ``keep_runs_results`` is off) the states are summed
    # over the trajectories. ``MultiTrajResult`` also compares this value
    # against ``None``: ``None`` with no ``e_ops`` is treated as enabled.
    store_states: bool
    # When truthy, and neither the averaged states nor the individual
    # trajectories are kept, only the final state is summed.
    store_final_state: bool
    # When truthy, every added trajectory and its ``e_data`` are kept.
    keep_runs_results: bool


class MultiTrajResult(_BaseResult):
    """
    Base class for storing results for solver using multiple trajectories.

    Parameters
    ----------
    e_ops : :obj:`.Qobj`, :obj:`.QobjEvo`, function or list or dict of these
        The ``e_ops`` parameter defines the set of values to record at
        each time step ``t``. If an element is a :obj:`.Qobj` or
        :obj:`.QobjEvo` the value recorded is the expectation value of that
        operator given the state at ``t``. If the element is a function, ``f``,
        the value recorded is ``f(t, state)``.

        The values are recorded in the ``.expect`` attribute of this result
        object. ``.expect`` is a list, where each item contains the values
        of the corresponding ``e_op``.

        Function ``e_ops`` must return a number so the average can be computed.

    options : dict
        The options for this result class.

    solver : str or None
        The name of the solver generating these results.

    stats : dict or None
        The stats generated by the solver while producing these results. Note
        that the solver may update the stats directly while producing results.

    kw : dict
        Additional parameters specific to a result sub-class.

    Attributes
    ----------
    times : list
        A list of the times at which the expectation values and states were
        recorded.

    average_states : list of :obj:`.Qobj`
        The state at each time ``t`` (if the recording of the state was
        requested) averaged over all trajectories as a density matrix.

    runs_states : list of list of :obj:`.Qobj`
        The state for each trajectory and each time ``t`` (if the recording of
        the states and trajectories was requested)

    average_final_state : :obj:`.Qobj`:
        The final state (if the recording of the final state was requested)
        averaged over all trajectories as a density matrix.

    runs_final_states : list of :obj:`.Qobj`
        The final state for each trajectory (if the recording of the final
        state and trajectories was requested).

    average_expect : list of array of expectation values
        A list containing the values of each ``e_op`` averaged over each
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    std_expect : list of array of expectation values
        A list containing the standard deviation of each ``e_op`` over each
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    runs_expect : list of array of expectation values
        A list containing the values of each ``e_op`` for each trajectories.
        The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given. Only available if the
        storing of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    average_e_data : dict
        A dictionary containing the values of each ``e_op`` averaged over each
        trajectories. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    std_e_data : dict
        A dictionary containing the standard deviation of each ``e_op`` over
        each trajectories. If the ``e_ops`` were supplied as a dictionary, the
        keys are the same as in that dictionary. Otherwise the keys are the
        index of the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    runs_e_data : dict
        A dictionary containing the values of each ``e_op`` for each
        trajectories. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list. Only available if the storing
        of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    runs_weights : list
        For each trajectory, the weight with which that trajectory enters
        averages.

    deterministic_weights : list
        For each deterministic trajectory, (when using improved_sampling)
        the weight with which that trajectory enters averages.

    solver : str or None
        The name of the solver generating these results.

    num_trajectories: int
        Number of trajectories computed.

    seeds: list of SeedSequence
        The seeds used to compute each trajectories.

    trajectories: list of Result
        If the option `keep_runs_results` is set, a list of all trajectories.

    deterministic_trajectories: list of Result
        A list of the no-jump trajectories if the option ``improved_sampling``
        is set.

    stats : dict or None
        The stats generated by the solver while producing these results.

    options : :obj:`~SolverResultsOptions`
        The options for this result class.
    """

    options: MultiTrajResultOptions

    def __init__(
        self, e_ops, options: MultiTrajResultOptions, *,
        solver=None, stats=None, **kw,
    ):
        """
        Set up an empty multi-trajectory result.

        Parameters
        ----------
        e_ops : object
            Expectation operators, normalised with ``_e_ops_to_dict``.
        options : dict
            Options of the result, forwarded to the parent class.
        solver : str or None
            Name of the solver, forwarded to the parent class.
        stats : dict or None
            Solver statistics, forwarded to the parent class.
        **kw
            Extra arguments handed to ``_post_init``.
        """
        super().__init__(options, solver=solver, stats=stats)
        self._raw_ops = self._e_ops_to_dict(e_ops)

        # Trajectory bookkeeping.
        self.num_trajectories = 0
        self.seeds = []
        self.trajectories = []
        self.deterministic_trajectories = []

        # Lazily computed averages and standard deviations.
        self._average_e_data = {}
        self._std_e_data = {}

        # Per-run expectation values are only kept on request.
        keep_runs = self.options["keep_runs_results"]
        self.runs_e_data = (
            {key: [] for key in self._raw_ops} if keep_runs else {}
        )

        # Filled in when the first trajectory arrives.
        self.times = None
        self.e_ops = None

        # Running sums are split in two: trajectories added with an explicit
        # absolute weight (``_sum_det``) and regular trajectories
        # (``_sum_rel``). Each is created when the first trajectory of its
        # kind is added.
        self._sum_det = None
        self._sum_rel = None

        # Weights are remembered so that results can be merged later.
        self._deterministic_weight_info = []
        self._trajectories_weight_info = []

        self._post_init(**kw)

    @property
    def _store_average_density_matrices(self) -> bool:
        """
        Whether the states must be summed while trajectories are added.

        True when ``store_states`` is truthy, or is ``None`` with no
        ``e_ops``, and the individual trajectories are not kept.
        """
        store_states = self.options["store_states"]
        requested = bool(store_states) or (
            store_states is None and self._raw_ops == {}
        )
        if not requested:
            return False
        return not self.options["keep_runs_results"]

    @property
    def _store_final_density_matrix(self) -> bool:
        """
        Whether the final state must be summed while trajectories are added.

        Only the case when ``store_final_state`` is set while neither the
        averaged states nor the individual trajectories are stored.
        """
        store_final = self.options["store_final_state"]
        if not store_final:
            # Hand back the falsy option value itself.
            return store_final
        if self._store_average_density_matrices:
            return False
        return not self.options["keep_runs_results"]

    def _add_first_traj(self, trajectory):
        """
        Take the ``times`` and ``e_ops`` from the first trajectory received.
        """
        self.times, self.e_ops = trajectory.times, trajectory.e_ops

    def _store_trajectory(self, trajectory, *, abs=None, rel=None):
        """
        Keep the trajectory in ``trajectories`` unless it was added with an
        absolute weight (``abs``).
        """
        if abs is not None:
            return
        self.trajectories.append(trajectory)

    def _reduce_states(self, trajectory, *, abs=None, rel=None):
        """
        Add the states of the trajectory to the matching running sum:
        ``_sum_det`` when ``abs`` is given, ``_sum_rel`` otherwise.
        """
        if abs is None:
            self._sum_rel.reduce_states(trajectory, rel)
            return
        self._sum_det.reduce_states(trajectory, abs)

    def _reduce_final_state(self, trajectory, *, abs=None, rel=None):
        """
        Add the final state of the trajectory to the matching running sum:
        ``_sum_det`` when ``abs`` is given, ``_sum_rel`` otherwise.
        """
        if abs is None:
            self._sum_rel.reduce_final_state(trajectory, rel)
            return
        self._sum_det.reduce_final_state(trajectory, abs)

    def _reduce_expect(self, trajectory, *, abs=None, rel=None):
        """
        Add the expectation values of the trajectory to the matching running
        sum. For trajectories without absolute weight, the per-run values are
        also appended to ``runs_e_data`` when that dictionary is in use.
        """
        if abs is not None:
            self._sum_det.reduce_expect(trajectory, abs)
            return

        self._sum_rel.reduce_expect(trajectory, rel)
        if not self.runs_e_data:
            return
        for key in self._raw_ops:
            self.runs_e_data[key].append(trajectory.e_data[key])

    def _create_e_data(self):
        """
        Fill ``_average_e_data`` and ``_std_e_data`` from the running sums.

        For each ``e_op`` the average is the deterministic sum plus the
        relative sum divided by ``num_trajectories``. The standard deviation
        is computed from the averaged squares and the squared average.

        Nothing is computed when no trajectory has been added yet; both
        dictionaries are then left empty.
        """
        if self._sum_det is None and self._sum_rel is None:
            # Without any sum, ``avg`` below would stay the integer 0 and
            # ``list(avg)`` would raise a TypeError.
            return

        for i, k in enumerate(self._raw_ops):
            avg = 0
            avg2 = 0
            if self._sum_det is not None:
                avg += self._sum_det.sum_expect[i]
                avg2 += self._sum_det.sum2_expect[i]
            if self._sum_rel is not None:
                avg += (
                    self._sum_rel.sum_expect[i] / self.num_trajectories
                )
                avg2 += (
                    self._sum_rel.sum2_expect[i] / self.num_trajectories
                )

            self._average_e_data[k] = list(avg)
            # mean(expect**2) - mean(expect)**2 can sometimes be a very small
            # negative number (-1e-15), which raises an error for float sqrt.
            self._std_e_data[k] = list(np.sqrt(np.abs(avg2 - np.abs(avg**2))))

    def _increment_traj(self, trajectory, *, abs=None, rel=None):
        """
        Register a new trajectory: initialize shared data on the very first
        one, count it when it has no absolute weight, and create the running
        sum of its kind if it does not exist yet.
        """
        nothing_added_yet = (
            self.num_trajectories == 0
            and not self._deterministic_weight_info
        )
        if nothing_added_yet:
            self._add_first_traj(trajectory)

        deterministic = abs is not None
        if deterministic:
            sum_missing = self._sum_det is None
        else:
            self.num_trajectories += 1
            sum_missing = self._sum_rel is None
        if not sum_missing:
            return

        fresh_sum = _TrajectorySum(
            trajectory,
            self._store_average_density_matrices,
            self._store_final_density_matrix,
        )
        if deterministic:
            self._sum_det = fresh_sum
        else:
            self._sum_rel = fresh_sum

    def _no_end(self):
        """
        End check used when no end condition is set: this object cannot tell
        how many trajectories are still needed, so infinity is returned.
        """
        remaining = np.inf
        return remaining

    def _fixed_end(self):
        """
        Finish at a known number of trajectories.

        Returns
        -------
        ntraj_left : int
            ``_target_ntraj`` minus the number of trajectories added so far.
            Zero or negative once the target is reached.
        """
        ntraj_left = self._target_ntraj - self.num_trajectories
        # ``<=`` rather than ``==`` so that an overshoot of the target is
        # also reported, as in ``_target_tolerance_end``.
        if ntraj_left <= 0:
            self.stats["end_condition"] = "ntraj reached"
        return ntraj_left

    def _average_computer(self):
        """
        Return the relative sums of the expectation values and of their
        squares, each divided by ``num_trajectories``, as arrays.
        """
        count = self.num_trajectories
        mean = np.array(self._sum_rel.sum_expect) / count
        mean_of_squares = np.array(self._sum_rel.sum2_expect) / count
        return mean, mean_of_squares

    def _target_tolerance_end(self):
        """
        Compute the error on the expectation values using jackknife resampling.
        Return the approximate number of trajectories needed to have this
        error within the tolerance for all e_ops and times.

        Returns
        -------
        number
            ``0`` once ``_target_ntraj`` is reached, infinity while at most
            one trajectory is available, otherwise the estimated number of
            trajectories still needed (capped by the ``ntraj`` target).
        """
        if self.num_trajectories >= self._target_ntraj:
            # First make sure that "ntraj" setting is always respected
            self.stats["end_condition"] = "ntraj reached"
            return 0

        if self.num_trajectories <= 1:
            return np.inf
        avg, avg2 = self._average_computer()
        # The relative part of the tolerance scales with the magnitude of the
        # mean. Using the signed (or complex) mean would shrink the target
        # for negative expectation values, possibly down to zero.
        target = np.array(
            [
                atol + rtol * abs(mean)
                for mean, (atol, rtol) in zip(avg, self._target_tols)
            ]
        )

        one = np.array(1)
        if sum(self._deterministic_weight_info):
            # We do not include deterministic traj. in this calculation.
            # When there is a deterministic trajectory, the weights don't add
            # up to one. We have to consider that as follows:
            # err = (std * <w>**2 / (N-1)) ** 0.5
            # avg = <x * w>
            # avg2 = <x**2 * w>
            # std * <w>**2 = (<x**2> - <x>**2) * <w>**2
            #              = avg2 * <w> - avg**2
            # and "<w>" is one minus the sum of all deterministic trajectories
            # weights
            one = one - sum(self._deterministic_weight_info)

        std = avg2 * one - abs(avg)**2
        target_ntraj = np.max(std / target**2) + 1
        self._estimated_ntraj = min(target_ntraj - self.num_trajectories,
                                    self._target_ntraj - self.num_trajectories)
        if self._estimated_ntraj <= 0:
            self.stats["end_condition"] = "target tolerance reached"
        return self._estimated_ntraj

    def _post_init(self):
        """
        Finish the initialization: reset the end condition and register the
        processors that are applied to every added trajectory.
        """
        self._target_ntraj = None
        self._target_tols = None
        self._early_finish_check = self._no_end

        # The counter always runs first; the others depend on the options.
        optional_processors = (
            (self.options["keep_runs_results"], self._store_trajectory),
            (self._store_average_density_matrices, self._reduce_states),
            (self._store_final_density_matrix, self._reduce_final_state),
            (self._raw_ops, self._reduce_expect),
        )
        self.add_processor(self._increment_traj)
        for enabled, processor in optional_processors:
            if enabled:
                self.add_processor(processor)

        self.stats["end_condition"] = "unknown"

    def add_deterministic(self, trajectory, weight):
        """
        Add a trajectory that was not generated randomly.

        The given weight is used as is for this trajectory in all averages.

        Parameters
        ----------
        trajectory : :class:`Result`
            The result of the simulation of the deterministic trajectory

        weight : float
            Number (usually between 0 and 1), exact weight of this trajectory
        """
        for processor in self._state_processors:
            processor(trajectory, abs=weight)

        self.deterministic_trajectories.append(trajectory)
        self._deterministic_weight_info.append(weight)

    def add(self, trajectory_info):
        """
        Add a trajectory to the evolution.

        Depending on the option ``keep_runs_results``, the trajectory is
        either stored or only folded into the averages.

        Parameters
        ----------
        trajectory_info : tuple of seed and trajectory
            - seed: int, SeedSequence
              Seed used to generate the trajectory.
            - trajectory : :class:`Result`
              Run result for one evolution over the times.
            - *weight: float, optional
              Relative weight of the trajectory. Defaults to 1.

        Returns
        -------
        remaining_traj : number
            The value of the active end check: the number of trajectories
            still needed, or infinity when no end condition is set.
        """
        seed, trajectory, *extra = trajectory_info
        weight = extra[0] if extra else 1

        self.seeds.append(seed)
        self._trajectories_weight_info.append(weight)

        for processor in self._state_processors:
            processor(trajectory, rel=weight)

        remaining = self._early_finish_check()
        return remaining

    def add_end_condition(self, ntraj, target_tol=None):
        """
        Set the condition under which the computation of trajectories stops.

        Supported end conditions are:

        - Reaching a number of trajectories.
        - The error on the expectation values becoming smaller than a given
          tolerance.

        Parameters
        ----------
        ntraj : int
            Number of trajectories expected.

        target_tol : float, array_like, [optional]
            Target tolerance of the evolution. Trajectories are computed
            until the error on the expectation values is lower than this
            tolerance. The error is estimated using jackknife resampling.
            ``target_tol`` can be an absolute tolerance, a pair of absolute
            and relative tolerance (in that order), or a list of pairs of
            (atol, rtol), one for each e_ops.

        Raises
        ------
        ValueError
            If a tolerance is given without e_ops, or if ``target_tol`` has
            an unsupported shape.
        """
        self._target_ntraj = ntraj
        self.stats["end_condition"] = "timeout"

        if target_tol is None:
            self._early_finish_check = self._fixed_end
            return

        n_ops = len(self._raw_ops)
        if n_ops == 0:
            raise ValueError("Cannot target a tolerance without e_ops")

        self._estimated_ntraj = ntraj

        tol_array = np.array(target_tol)
        if tol_array.ndim == 0:
            # A single number is an absolute tolerance for every e_op.
            tolerances = np.array([(target_tol, 0.0)] * n_ops)
        elif tol_array.shape == (2,):
            # One (atol, rtol) pair shared by every e_op.
            tolerances = np.ones((n_ops, 2)) * tol_array
        elif tol_array.shape == (n_ops, 2):
            tolerances = tol_array
        else:
            raise ValueError(
                "target_tol must be a number, a pair of (atol, "
                "rtol) or a list of (atol, rtol) for each e_ops"
            )

        self._target_tols = tolerances
        self._early_finish_check = self._target_tolerance_end

    @property
    def runs_states(self):
        """
        States of every run as ``states[run][t]``, or ``None`` when no
        trajectory with states is stored.
        """
        stored = self.trajectories
        if not (stored and stored[0].states):
            return None
        return [run.states for run in stored]

    @property
    def average_states(self):
        """
        States averages as density matrices.

        Returns ``None`` when no trajectory has been added yet, or when the
        states were not summed and cannot be rebuilt from stored
        trajectories.
        """
        if self._sum_det is None and self._sum_rel is None:
            # Nothing was added yet. Without this guard the final return
            # would dereference ``None``.
            return None

        trajectory_states_available = (self.trajectories and
                                       self.trajectories[0].states)
        need_to_reduce_states = False
        if self._sum_det and not self._sum_det.sum_states:
            if not trajectory_states_available:
                return None
            self._sum_det._initialize_sum_states(self.trajectories[0])
            need_to_reduce_states = True
        if self._sum_rel and not self._sum_rel.sum_states:
            if not trajectory_states_available:
                return None
            self._sum_rel._initialize_sum_states(self.trajectories[0])
            need_to_reduce_states = True
        if need_to_reduce_states:
            # The sums were not accumulated while adding trajectories;
            # rebuild them from the stored trajectories and their weights.
            for trajectory, weight in zip(
                self.deterministic_trajectories,
                self._deterministic_weight_info
            ):
                self._reduce_states(trajectory, abs=weight)
            for trajectory, weight in zip(
                self.trajectories,
                self._trajectories_weight_info
            ):
                self._reduce_states(trajectory, rel=weight)

        if self._sum_det and self._sum_rel:
            return [a + r / self.num_trajectories for a, r in zip(
                self._sum_det.sum_states, self._sum_rel.sum_states)
            ]
        if self._sum_rel:
            return [r / self.num_trajectories
                    for r in self._sum_rel.sum_states]
        return self._sum_det.sum_states

    @property
    def states(self):
        """
        States of every run if available, averaged states otherwise.
        """
        per_run = self.runs_states
        return per_run if per_run else self.average_states

    @property
    def runs_final_states(self):
        """
        Final state of each stored trajectory, or ``None`` when no
        trajectory with a final state is stored.
        """
        stored = self.trajectories
        if not (stored and stored[0].final_state):
            return None
        return [run.final_state for run in stored]

    @property
    def average_final_state(self):
        """
        Last states of each trajectories averaged into a density matrix.

        Returns ``None`` when no trajectory has been added yet, or when
        neither final states nor averaged states are available.
        """
        if self._sum_det is None and self._sum_rel is None:
            # Nothing was added yet. Without this guard the final return
            # would dereference ``None``.
            return None

        trajectory_states_available = (self.trajectories and
                                       self.trajectories[0].final_state)
        states = self.average_states
        need_to_reduce_states = False
        if self._sum_det and not self._sum_det.sum_final_state:
            if not (trajectory_states_available or states):
                return None
            need_to_reduce_states = True

        if self._sum_rel and not self._sum_rel.sum_final_state:
            if not (trajectory_states_available or states):
                return None
            need_to_reduce_states = True

        if need_to_reduce_states and states:
            # The averaged states are known: the last one is the answer.
            return states[-1]
        elif need_to_reduce_states:
            # Rebuild the sums of final states from the stored trajectories.
            if self._sum_det:
                self._sum_det._initialize_sum_finalstate(self.trajectories[0])
            if self._sum_rel:
                self._sum_rel._initialize_sum_finalstate(self.trajectories[0])
            for trajectory, weight in zip(
                self.deterministic_trajectories,
                self._deterministic_weight_info
            ):
                self._reduce_final_state(trajectory, abs=weight)
            for trajectory, weight in zip(
                self.trajectories,
                self._trajectories_weight_info
            ):
                self._reduce_final_state(trajectory, rel=weight)

        if self._sum_det and self._sum_rel:
            return (self._sum_det.sum_final_state +
                    self._sum_rel.sum_final_state / self.num_trajectories)
        if self._sum_rel:
            return self._sum_rel.sum_final_state / self.num_trajectories
        return self._sum_det.sum_final_state

    @property
    def final_state(self):
        """
        Final states of every run if available, averaged final state
        otherwise.
        """
        per_run = self.runs_final_states
        return per_run if per_run else self.average_final_state

    @property
    def average_e_data(self):
        """
        Dictionary of the averaged expectation values, computed on first
        access while the cache is empty.
        """
        if self._average_e_data:
            return self._average_e_data
        self._create_e_data()
        return self._average_e_data

    @property
    def std_e_data(self):
        """
        Dictionary of the standard deviations of the expectation values,
        computed on first access while the cache is empty.
        """
        if self._std_e_data:
            return self._std_e_data
        self._create_e_data()
        return self._std_e_data

    @property
    def average_expect(self):
        """
        Values of ``average_e_data`` as a list of arrays.
        """
        averages = self.average_e_data
        return [np.array(values) for values in averages.values()]

    @property
    def std_expect(self):
        """
        Values of ``std_e_data`` as a list of arrays.
        """
        deviations = self.std_e_data
        return [np.array(values) for values in deviations.values()]

    @property
    def runs_expect(self):
        """
        Values of ``runs_e_data`` as a list of arrays.
        """
        per_run = self.runs_e_data
        return [np.array(values) for values in per_run.values()]

    @property
    def expect(self):
        """
        Values of ``e_data`` as a list of arrays.
        """
        data = self.e_data
        return [np.array(values) for values in data.values()]

    @property
    def e_data(self):
        """
        Per-run expectation values when they are stored, averaged ones
        otherwise.
        """
        per_run = self.runs_e_data
        if per_run:
            return per_run
        return self.average_e_data

    @property
    def deterministic_weights(self):
        """
        Copy of the weights of the deterministic trajectories.
        """
        return list(self._deterministic_weight_info)

    @property
    def runs_weights(self):
        """
        Weight of each regular trajectory divided by ``num_trajectories``.
        """
        count = self.num_trajectories
        return [weight / count for weight in self._trajectories_weight_info]

    def steady_state(self, N=0):
        """
        Average the averaged states over the last ``N`` times as a density
        matrix. Should converge to the steady state in the right
        circumstances.

        Parameters
        ----------
        N : int [optional]
            Number of states from the end of ``tlist`` to average. By
            default, or when larger than the number of times, all states are
            averaged.

        Returns
        -------
        The averaged state, or ``None`` when ``average_states`` is ``None``.
        """
        count = int(N)
        n_times = len(self.times)
        if count == 0 or count > n_times:
            count = n_times

        averaged = self.average_states
        if averaged is None:
            return None
        return sum(averaged[-count:]) / count

    def __repr__(self):
        """
        Multi-line summary: solver, stats, time interval, number of e_ops,
        which states are saved and the number of trajectories.
        """
        out = [f"<{self.__class__.__name__}", f"  Solver: {self.solver}"]

        if self.stats:
            out.append("  Solver stats:")
            for key, value in self.stats.items():
                out.append(f"    {key}: {value!r}")

        if self.times:
            first, last = self.times[0], self.times[-1]
            out.append(
                f"  Time interval: [{first}, {last}]"
                f" ({len(self.times)} steps)"
            )

        out.append(f"  Number of e_ops: {len(self.e_data)}")

        if self.states:
            state_line = "  States saved."
        elif self.final_state is not None:
            state_line = "  Final state saved."
        else:
            state_line = "  State not saved."
        out.append(state_line)

        out.append(f"  Number of trajectories: {self.num_trajectories}")
        out.append(
            "  Trajectories saved." if self.trajectories
            else "  Trajectories not saved."
        )
        out.append(">")
        return "\n".join(out)

    def merge(self, other, p=None):
        r"""
        Merges two multi-trajectory results.

        If this result represent an ensemble :math:`\rho`, and `other`
        represents an ensemble :math:`\rho'`, then the merged result
        represents the ensemble

        .. math::
            \rho_{\mathrm{merge}} = p \rho + (1 - p) \rho'

        where p is a parameter between 0 and 1. Its default value is
        :math:`p_{\textrm{def}} = N / (N + N')`, N and N' being the number of
        trajectories in the two result objects.

        The merged result gets its own copy of the ``stats`` dictionary, so
        neither operand is modified by the merge.

        Parameters
        ----------
        other : MultiTrajResult
            The multi-trajectory result to merge with this one
        p : float [optional]
            The relative weight of this result in the combination. By default,
            will be chosen such that all trajectories contribute equally
            to the merged result.

        Returns
        -------
        MultiTrajResult
            A new result of the same class as ``self``, or ``NotImplemented``
            if ``other`` is not a ``MultiTrajResult``.

        Raises
        ------
        ValueError
            If the two results do not share the same ``e_ops`` or ``times``.
        """
        if not isinstance(other, MultiTrajResult):
            return NotImplemented
        if self._raw_ops != other._raw_ops:
            raise ValueError("Shared `e_ops` is required to merge results")
        if self.times != other.times:
            raise ValueError("Shared `times` are is required to merge results")

        # The constructor and the code below write into ``stats``. Passing
        # ``self.stats`` itself would alter this result's own stats (end
        # condition, run time), so a shallow copy is handed over instead.
        new = self.__class__(
            self._raw_ops, self.options, solver=self.solver,
            stats=copy(self.stats),
        )
        new.times = self.times
        new.e_ops = self.e_ops

        if bool(self.trajectories) != bool(other.trajectories):
            # ensure the states are reduced.
            if self.trajectories:
                self.average_states
                self.average_final_state
            else:
                other.average_states
                other.average_final_state

        new.num_trajectories = self.num_trajectories + other.num_trajectories
        new.seeds = self.seeds + other.seeds

        # ``p_equal`` is the weight for which every trajectory counts the
        # same; relative weights are rescaled by ``p / p_equal``.
        p_equal = self.num_trajectories / new.num_trajectories
        if p is None:
            p = p_equal

        new._deterministic_weight_info = [
            w * p for w in self._deterministic_weight_info
        ] + [
            w * (1 - p) for w in other._deterministic_weight_info
        ]
        new._trajectories_weight_info = [
            w * p / p_equal for w in self._trajectories_weight_info
        ] + [
            w * (1 - p) / (1 - p_equal)
            for w in other._trajectories_weight_info
        ]

        if self.trajectories and other.trajectories:
            new.deterministic_trajectories = (
                self.deterministic_trajectories
                + other.deterministic_trajectories
            )
            new.trajectories = self.trajectories + other.trajectories
        else:
            # Individual runs are only kept when both results have them.
            new.trajectories = []
            new.options["keep_runs_results"] = False
            new.runs_e_data = {}

        self_states = self.options["store_states"]
        self_fstate = self.options["store_final_state"]
        other_states = other.options["store_states"]
        other_fstate = other.options["store_final_state"]

        new.options["store_states"] = self_states and other_states

        new.options["store_final_state"] = (
            (self_fstate or self_states) and (other_fstate or other_states)
        )

        new._sum_det = _TrajectorySum.merge(
            self._sum_det, p, other._sum_det, 1 - p)
        new._sum_rel = _TrajectorySum.merge(
            self._sum_rel, p / p_equal,
            other._sum_rel, (1 - p) / (1 - p_equal))

        if self.runs_e_data and other.runs_e_data:
            for k in self._raw_ops:
                new.runs_e_data[k] = self.runs_e_data[k] + other.runs_e_data[k]

        new.stats["run time"] += other.stats["run time"]
        new.stats["end_condition"] = "Merged results"

        return new

    def __add__(self, other):
        """
        ``self + other`` merges the two results with the default weight.
        """
        merged = self.merge(other, p=None)
        return merged


class _TrajectorySum:
    """
    Running sums over a set of trajectories that are added one at a time:
    expectation values, their squares and, on request, the states and the
    final state. `MultiTrajResult` keeps several sums of this kind.

    Parameters
    ----------
    example_trajectory : :obj:`.Result`
        A trajectory whose expectation values and states have the same shape
        as those of the trajectories added later. It is only used to create
        zero-filled containers of the right shape.

    store_states : bool
        Whether the states of the trajectories will be summed.

    store_final_state : bool
        Whether the final states of the trajectories will be summed.
    """
    def __init__(self, example_trajectory, store_states, store_final_state):
        # State sums stay ``None`` unless requested and available.
        self.sum_states = None
        self.sum_final_state = None
        if example_trajectory.states and store_states:
            self._initialize_sum_states(example_trajectory)
        if example_trajectory.final_state and store_final_state:
            self._initialize_sum_finalstate(example_trajectory)

        template = example_trajectory.expect
        self.sum_expect = [np.zeros_like(values) for values in template]
        self.sum2_expect = [np.zeros_like(values) for values in template]

    def _initialize_sum_states(self, example_trajectory):
        """Create one zero operator per state of the example trajectory."""
        zeros = []
        for state in example_trajectory.states:
            zeros.append(qzero_like(_to_dm(state)))
        self.sum_states = zeros

    def _initialize_sum_finalstate(self, example_trajectory):
        """Create a zero operator shaped like the example final state."""
        final_dm = _to_dm(example_trajectory.final_state)
        self.sum_final_state = qzero_like(final_dm)

    def reduce_states(self, trajectory, weight=1., td_weight=None):
        """
        Add the weighted states of the trajectory to `sum_states`. When
        `td_weight` is given, each state is additionally multiplied by the
        matching entry of `td_weight`.
        """
        if td_weight is None:
            self.sum_states = [
                total + weight * _to_dm(state)
                for total, state in zip(self.sum_states, trajectory.states)
            ]
            return

        self.sum_states = [
            total + weight * weight_t * _to_dm(state)
            for total, state, weight_t in zip(
                self.sum_states, trajectory.states, td_weight
            )
        ]

    def reduce_final_state(self, trajectory, weight=1.):
        """
        Add the weighted final state of the trajectory to `sum_final_state`.
        """
        contribution = weight * _to_dm(trajectory.final_state)
        self.sum_final_state += contribution

    def reduce_expect(self, trajectory, weight=1.):
        """
        Add the weighted expectation values of the trajectory, and their
        weighted squares, to `sum_expect` and `sum2_expect`.
        """
        for idx, values in enumerate(trajectory.expect):
            self.sum_expect[idx] += weight * values
            self.sum2_expect[idx] += weight * values**2

    @staticmethod
    def merge(sum1, weight1, sum2, weight2):
        """
        Combine two sums as `weight1 * sum1 + weight2 * sum2`. Either sum may
        be `None`; `None` is returned when both are.
        """
        if sum1 is None:
            if sum2 is None:
                return None
            # Make sure the first operand is the one that exists.
            sum1, weight1, sum2, weight2 = sum2, weight2, sum1, weight1

        merged = copy(sum1)

        if sum2 is None:
            # Only one operand: scale whatever it holds.
            if sum1.sum_states:
                merged.sum_states = [
                    weight1 * state for state in sum1.sum_states
                ]
            if sum1.sum_final_state:
                merged.sum_final_state = weight1 * sum1.sum_final_state
            merged.sum_expect = [weight1 * val for val in sum1.sum_expect]
            merged.sum2_expect = [weight1 * val for val in sum1.sum2_expect]
            return merged

        # State sums survive only when both operands have them.
        merged.sum_states = None
        if sum1.sum_states and sum2.sum_states:
            merged.sum_states = [
                weight1 * first + weight2 * second
                for first, second in zip(sum1.sum_states, sum2.sum_states)
            ]

        merged.sum_final_state = None
        if sum1.sum_final_state and sum2.sum_final_state:
            merged.sum_final_state = (
                weight1 * sum1.sum_final_state
                + weight2 * sum2.sum_final_state
            )

        merged.sum_expect = [
            weight1 * first + weight2 * second
            for first, second in zip(sum1.sum_expect, sum2.sum_expect)
        ]
        merged.sum2_expect = [
            weight1 * first + weight2 * second
            for first, second in zip(sum1.sum2_expect, sum2.sum2_expect)
        ]

        return merged


class _McBaseResult(MultiTrajResult):
    # Collapses are only produced by mcsolve-like solvers; this class adds
    # the bookkeeping for them on top of ``MultiTrajResult``.
    def _add_collapse(self, trajectory, *, rel=None, abs=None):
        # Processor registered in ``_post_init``: only trajectories added
        # with a relative weight have their collapses recorded.
        if rel is None:
            return
        self.collapse.append(trajectory.collapse)

    def _post_init(self):
        super()._post_init()
        self.num_c_ops = self.stats["num_collapse"]
        self.collapse = []
        self.add_processor(self._add_collapse)

    def _collapse_column(self, index):
        # Transpose the ``(time, which)`` pairs of every run and keep the
        # requested column. Runs without any collapse yield an empty list.
        columns = []
        for events in self.collapse:
            transposed = list(zip(*events))
            columns.append(transposed[index] if transposed else [])
        return columns

    @property
    def col_times(self):
        """
        List of the times of the collapses for each run.
        """
        return self._collapse_column(0)

    @property
    def col_which(self):
        """
        List of the indexes of the collapses for each run.
        """
        return self._collapse_column(1)

    def merge(self, other, p=None):
        # Merge as the parent class does, then concatenate the recorded
        # collapses of both results (``self`` first).
        merged = super().merge(other, p)
        merged.collapse = self.collapse + other.collapse
        return merged


class McResult(_McBaseResult):
    """
    Class for storing Monte-Carlo solver results.

    Attributes
    ----------

    times : list
        A list of the times at which the expectation values and states were
        recorded.

    average_states : list of :obj:`.Qobj`
        The state at each time ``t`` (if the recording of the state was
        requested) averaged over all trajectories as a density matrix.

    runs_states : list of list of :obj:`.Qobj`
        The state for each trajectory and each time ``t`` (if the recording of
        the states and trajectories was requested)

    average_final_state : :obj:`.Qobj`
        The final state (if the recording of the final state was requested)
        averaged over all trajectories as a density matrix.

    runs_final_states : list of :obj:`.Qobj`
        The final state for each trajectory (if the recording of the final
        state and trajectories was requested).

    average_expect : list of array of expectation values
        A list containing the values of each ``e_op`` averaged over all
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    std_expect : list of array of expectation values
        A list containing the standard deviation of each ``e_op`` over all
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    runs_expect : list of array of expectation values
        A list containing the values of each ``e_op`` for each trajectory.
        The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given. Only available if the
        storing of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    average_e_data : dict
        A dictionary containing the values of each ``e_op`` averaged over all
        trajectories. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    std_e_data : dict
        A dictionary containing the standard deviation of each ``e_op`` over
        all trajectories. If the ``e_ops`` were supplied as a dictionary, the
        keys are the same as in that dictionary. Otherwise the keys are the
        index of the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    runs_e_data : dict
        A dictionary containing the values of each ``e_op`` for each
        trajectory. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list. Only available if the storing
        of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    num_trajectories: int
        Number of trajectories computed.

    seeds: list of SeedSequence
        The seeds used to compute each trajectory.

    trajectories: list of Result
        If the option `keep_runs_results` is set, a list of all trajectories.

    deterministic_trajectories: list of Result
        A list of the no-jump trajectories if the option ``improved_sampling``
        is set.

    runs_weights : list
        For each trajectory, the weight with which that trajectory enters
        averages.

    deterministic_weights : list
        For each deterministic trajectory, (when using improved_sampling)
        the weight with which that trajectory enters averages.

    collapse : list
        For each run, a list of every collapse as a tuple of the time it
        happened and the corresponding ``c_ops`` index.

    col_times : list
        List of the times of the collapses for each run.

    col_which : list
        List of the indexes of the collapses for each run.

    photocurrent : array
        Average photocurrent or measurement of the evolution.

    runs_photocurrent : list[array]
        Photocurrent or measurement of each run.

    stats : dict or None
        The stats generated by the solver while producing these results.

    solver : str or None
        The name of the solver generating these results.

    options : :obj:`~SolverResultsOptions`
        The options for this result class.
    """

    @property
    def photocurrent(self):
        """
        Average photocurrent or measurement of the evolution.

        One array per collapse operator: the weighted histogram of the
        collapse times over the bins delimited by ``times``, divided by the
        width of each bin.
        """
        jump_times = [[] for _ in range(self.num_c_ops)]
        jump_weights = [[] for _ in range(self.num_c_ops)]
        bins = self.times

        # Group the collapse times per operator; every collapse carries the
        # weight of the run it belongs to.
        for run_collapses, run_weight in zip(self.collapse, self.runs_weights):
            for time, op_index in run_collapses:
                jump_times[op_index].append(time)
                jump_weights[op_index].append(run_weight)

        measurement = []
        for times, weights in zip(jump_times, jump_weights):
            counts = np.histogram(times, bins=bins, weights=weights)[0]
            measurement.append(counts / np.diff(bins))
        return measurement

    @property
    def runs_photocurrent(self):
        """
        Photocurrent or measurement of each run.

        For every run, one array per collapse operator: the histogram of the
        collapse times over the bins delimited by ``times``, divided by the
        width of each bin.
        """
        bins = self.times
        per_run = []
        for run_collapses in self.collapse:
            jump_times = [[] for _ in range(self.num_c_ops)]
            for time, op_index in run_collapses:
                jump_times[op_index].append(time)

            run_measurement = []
            for times in jump_times:
                counts = np.histogram(times, bins)[0]
                run_measurement.append(counts / np.diff(bins))
            per_run.append(run_measurement)
        return per_run


class NmmcResult(_McBaseResult):
    """
    Class for storing the results of the non-Markovian Monte-Carlo solver.

    Attributes
    ----------

    times : list
        A list of the times at which the expectation values and states were
        recorded.

    average_states : list of :obj:`.Qobj`
        The state at each time ``t`` (if the recording of the state was
        requested) averaged over all trajectories as a density matrix.

    runs_states : list of list of :obj:`.Qobj`
        The state for each trajectory and each time ``t`` (if the recording of
        the states and trajectories was requested)

    average_final_state : :obj:`.Qobj`
        The final state (if the recording of the final state was requested)
        averaged over all trajectories as a density matrix.

    runs_final_states : list of :obj:`.Qobj`
        The final state for each trajectory (if the recording of the final
        state and trajectories was requested).

    average_expect : list of array of expectation values
        A list containing the values of each ``e_op`` averaged over all
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    std_expect : list of array of expectation values
        A list containing the standard deviation of each ``e_op`` over all
        trajectories. The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    runs_expect : list of array of expectation values
        A list containing the values of each ``e_op`` for each trajectory.
        The list is in the same order in which the ``e_ops`` were
        supplied and empty if no ``e_ops`` were given. Only available if the
        storing of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        Each element is itself an array and contains the values of the
        corresponding ``e_op``, with one value for each time in ``.times``.

    average_e_data : dict
        A dictionary containing the values of each ``e_op`` averaged over all
        trajectories. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    std_e_data : dict
        A dictionary containing the standard deviation of each ``e_op`` over
        all trajectories. If the ``e_ops`` were supplied as a dictionary, the
        keys are the same as in that dictionary. Otherwise the keys are the
        index of the ``e_op`` in the ``.expect`` list.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    runs_e_data : dict
        A dictionary containing the values of each ``e_op`` for each
        trajectory. If the ``e_ops`` were supplied as a dictionary, the keys
        are the same as in that dictionary. Otherwise the keys are the index of
        the ``e_op`` in the ``.expect`` list. Only available if the storing
        of trajectories was requested.

        The order of the elements is ``runs_expect[e_ops][trajectory][time]``.

        The lists of expectation values returned are the *same* lists as
        those returned by ``.expect``.

    num_trajectories: int
        Number of trajectories computed.

    seeds: list of SeedSequence
        The seeds used to compute each trajectory.

    trajectories: list of Result
        If the option `keep_runs_results` is set, a list of all trajectories.

    deterministic_trajectories: list of Result
        A list of the no-jump trajectories if the option ``improved_sampling``
        is set.

    runs_weights : list
        For each trajectory, the weight with which that trajectory enters
        averages.

    deterministic_weights : list
        For each deterministic trajectory, (when using improved_sampling)
        the weight with which that trajectory enters averages.

    collapse : list
        For each run, a list of every collapse as a tuple of the time it
        happened and the corresponding ``c_ops`` index.

    col_times : list
        List of the times of the collapses for each run.

    col_which : list
        List of the indexes of the collapses for each run.

    average_trace : list
        The average trace (i.e., averaged over all trajectories) at each time.

    std_trace : list
        The standard deviation of the trace at each time.

    runs_trace : list of lists
        For each recorded trajectory, the trace at each time.
        Only present if ``keep_runs_results`` is set in the options.

    stats : dict or None
        The stats generated by the solver while producing these results.

    solver : str or None
        The name of the solver generating these results.

    options : :obj:`~SolverResultsOptions`
        The options for this result class.
    """

    def _post_init(self):
        super()._post_init()

        # Running sums of the trace and of its squared modulus. The ``_det``
        # sums collect trajectories added with an absolute weight, the
        # ``_rel`` sums those added with a relative weight. They are
        # allocated in ``_add_first_traj``.
        self._sum_trace_det = None
        self._sum_trace_rel = None
        self._sum2_trace_det = None
        self._sum2_trace_rel = None

        # Lazily computed by ``_compute_avg_trace``; ``None`` means that the
        # values must be (re)computed from the running sums.
        self._average_trace = None
        self._std_trace = None
        self.runs_trace = []

        self.add_processor(self._add_trace)

    def _reduce_states(self, trajectory, *, abs=None, rel=None):
        # Same dispatch as for the traces, with the trajectory's trace
        # handed over to the sum as an extra argument.
        if abs is not None:
            self._sum_det.reduce_states(trajectory, abs, trajectory.trace)
        else:
            self._sum_rel.reduce_states(trajectory, rel, trajectory.trace)

    def _reduce_final_state(self, trajectory, *, abs=None, rel=None):
        # The final state is weighted with the trace at the last time.
        if abs is not None:
            self._sum_det.reduce_final_state(
                trajectory, abs * trajectory.trace[-1])
        else:
            self._sum_rel.reduce_final_state(
                trajectory, rel * trajectory.trace[-1])

    def _reduce_expect(self, trajectory, *, abs=None, rel=None):
        """
        Add the expectation values of ``trajectory`` to the running sums,
        weighting each time step with the trajectory's trace at that time.
        """
        if abs is not None:
            self._sum_det.reduce_expect(
                trajectory, abs * np.array(trajectory.trace))
        else:
            self._sum_rel.reduce_expect(
                trajectory, rel * np.array(trajectory.trace))

            # Per-run data is only recorded for trajectories added with a
            # relative weight.
            if self.runs_e_data:
                for k in self._raw_ops:
                    self.runs_e_data[k].append(trajectory.e_data[k])

    def _add_first_traj(self, trajectory):
        super()._add_first_traj(trajectory)
        # Allocate the running sums with the shape of the first trace.
        self._sum_trace_det = np.zeros_like(trajectory.trace)
        self._sum_trace_rel = np.zeros_like(trajectory.trace)
        self._sum2_trace_det = np.zeros_like(trajectory.trace)
        self._sum2_trace_rel = np.zeros_like(trajectory.trace)

    def _add_trace(self, trajectory, *, abs=None, rel=None):
        """
        Processor accumulating the trace of ``trajectory`` (and its squared
        modulus) into the running sums matching the kind of weight given.
        """
        if abs is not None:
            self._sum_trace_det += np.array(trajectory.trace) * abs
            self._sum2_trace_det += np.abs(trajectory.trace) ** 2 * abs
        else:
            self._sum_trace_rel += np.array(trajectory.trace) * rel
            self._sum2_trace_rel += np.abs(trajectory.trace) ** 2 * rel

        # The running sums changed: drop the cached average and standard
        # deviation so that ``average_trace`` / ``std_trace`` are recomputed
        # on next access instead of returning stale values.
        self._average_trace = None
        self._std_trace = None

        # NOTE(review): unlike ``runs_e_data`` in ``_reduce_expect``, the
        # trace is stored for both kinds of trajectories -- confirm that
        # this is intended.
        if self.options["keep_runs_results"]:
            self.runs_trace.append(trajectory.trace)

    def _compute_avg_trace(self):
        """
        Compute the average and the standard deviation of the trace from the
        running sums and cache them.

        The absolute-weight sums enter as they are; the relative-weight sums
        are divided by ``num_trajectories`` (and skipped when it is zero).
        """
        avg = self._sum_trace_det
        avg2 = self._sum2_trace_det
        if self.num_trajectories > 0:
            avg = avg + self._sum_trace_rel / self.num_trajectories
            avg2 = avg2 + self._sum2_trace_rel / self.num_trajectories

        self._average_trace = avg
        # std**2 = <|tr|**2> - |<tr>|**2; the outer ``abs`` keeps the
        # argument of the square root non-negative.
        self._std_trace = np.sqrt(np.abs(avg2 - np.abs(avg) ** 2))

    @property
    def average_trace(self):
        """
        The average trace (i.e., averaged over all trajectories) at each time.
        """
        if self._average_trace is None:
            self._compute_avg_trace()
        return self._average_trace

    @property
    def std_trace(self):
        """
        The standard deviation of the trace at each time.
        """
        if self._std_trace is None:
            self._compute_avg_trace()
        return self._std_trace

    @property
    def trace(self):
        """
        Refers to ``average_trace`` or ``runs_trace``, depending on whether
        ``keep_runs_results`` is set in the options.
        """
        return self.runs_trace or self.average_trace

    def merge(self, other, p=None):
        """
        Merge two results as the parent class does and combine the trace
        statistics accordingly.

        Parameters
        ----------
        other : NmmcResult
            The result to merge with this one.

        p : float, optional
            Weight given to ``self`` in the merged result; ``other`` gets
            ``1 - p``. By default, ``self.num_trajectories`` divided by the
            number of trajectories of the merged result.

        Returns
        -------
        The merged result, as created by the parent class.

        Notes
        -----
        The relative-weight sums are rescaled by ``p / p_eq`` and
        ``(1 - p) / (1 - p_eq)``, where ``p_eq`` is the default value of
        ``p``; both results must therefore have a non-zero
        ``num_trajectories``.
        """
        new = super().merge(other, p)

        p_eq = self.num_trajectories / new.num_trajectories
        if p is None:
            p = p_eq

        new._sum_trace_det = (
            p * self._sum_trace_det +
            (1 - p) * other._sum_trace_det
        )
        new._sum2_trace_det = (
            p * self._sum2_trace_det +
            (1 - p) * other._sum2_trace_det
        )
        new._sum_trace_rel = (
            (p / p_eq) * self._sum_trace_rel +
            ((1 - p) / (1 - p_eq)) * other._sum_trace_rel
        )
        new._sum2_trace_rel = (
            (p / p_eq) * self._sum2_trace_rel +
            ((1 - p) / (1 - p_eq)) * other._sum2_trace_rel
        )
        new._compute_avg_trace()

        # The per-run traces are only kept when both results stored them.
        if self.runs_trace and other.runs_trace:
            new.runs_trace = self.runs_trace + other.runs_trace

        return new


def _to_dm(state):
    # Kets are replaced by their projector (``state.proj()``); any other
    # state is handed back untouched.
    return state.proj() if state.type == "ket" else state
